package com.diven.spark.ml.learn.feature

import org.apache.spark.sql.{DataFrame}
import org.apache.spark.sql.functions._
import com.diven.spark.ml.learn.core.BaseTest
import com.diven.spark.ml.learn.core.BaseSpark
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.ChiSqSelector
import org.apache.spark.ml.feature.ChiSqSelectorModel

/**
 * Companion object of the ChiSqSelectorTest class. It extends BaseTest and
 * provides a factory for the example defined in the class.
 */
object ChiSqSelectorTest extends BaseTest {

    /**
     * Creates a new ChiSqSelectorTest instance.
     *
     * @return the newly created example, typed as BaseSpark
     */
    def apply(): BaseSpark = {
        new ChiSqSelectorTest()
    }

}

/**
 * Example of chi-squared feature selection with Spark ML's ChiSqSelector.
 */
class ChiSqSelectorTest extends BaseSpark{

    /**
     * Runs the chi-squared feature selection example on the given data.
     *
     * The "weight", "height" and "age" columns are assembled into a single
     * vector column, the rows are randomly split 80/20 (no seed is passed, so
     * the split differs between runs), a ChiSqSelector is fitted on the first
     * part using "qualified" as the label, the names of the selected features
     * are printed, and the first 10 transformed rows of the second part are shown.
     *
     * @param dataFrame input data; the columns id, weight, height, age and
     *                  qualified are selected from it
     */
    override def execute(dataFrame: DataFrame) = {
        // Names of the candidate feature columns.
        val featureNames = Array("weight", "height", "age")

        // Pack the feature columns into one vector column.
        val assembler = new VectorAssembler()
            .setInputCols(featureNames)
            .setOutputCol("vector_features")
        val assembled = assembler.transform(
            dataFrame.select("id", "weight", "height", "age", "qualified"))

        // Split into a training part (80%) and a test part (20%).
        val Array(trainingSet, testSet) = assembled.randomSplit(Array(0.8, 0.2))

        // Configure the chi-squared selector.
        val selector = new ChiSqSelector()
            .setFeaturesCol("vector_features")     // input feature vector
            .setOutputCol("selected_features")     // reduced feature vector
            .setLabelCol("qualified")              // label column
            .setNumTopFeatures(2)                  // reduce the 3 original features to 2
            .setSelectorType("numTopFeatures")     // one of: numTopFeatures, percentile, fpr
            .setFpr(0.05)                          // only relevant when selectorType = "fpr"
            .setPercentile(0.0)                    // only relevant when selectorType = "percentile"

        // Fit the selector on the training part.
        val model: ChiSqSelectorModel = selector.fit(trainingSet)

        // Print the names of the features the model selected, each followed by a comma.
        for (index <- model.selectedFeatures) {
            print(featureNames(index) + ",")
        }

        // Apply the fitted model to the test part and display the first 10 rows.
        model.transform(testSet).show(10)
    }
}
